package org.leon.filters;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Global gateway filter that throttles clients by IP address.
 *
 * <p>Every request is counted in a per-IP, per-minute bucket. When the number of requests from one
 * IP inside the configured time window exceeds {@code lagou-service.gateway.ip-limit.limit}, the
 * request is answered with {@code 429 Too Many Requests} instead of being forwarded.
 *
 * <p>Thread-safety: per-IP state is published atomically via {@link ConcurrentHashMap#putIfAbsent}
 * and each IP's bucket map is only touched while holding that {@link IpInfo}'s monitor, so
 * concurrent invocations of {@link #filter} do not lose counts.
 */
@Component
public class IpLimitFilter implements GlobalFilter, Ordered {

    /** Throttling key used when the request's remote address is unavailable. */
    private static final String UNKNOWN_CLIENT = "unknown";

    /**
     * Formats a timestamp as a minute bucket key such as {@code 202005011210}.
     *
     * <p>Uses {@code yyyy} (calendar year), not {@code YYYY} (week-based year): the latter yields
     * out-of-order keys around New Year. The fixed-width numeric layout makes the lexical order of
     * keys equal to chronological order, which the {@link TreeMap} range operations rely on.
     */
    private static final DateTimeFormatter MINUTE_KEY_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMddHHmm");

    // NOTE(review): entries are never removed from this map, so it grows with the number of
    // distinct client IPs seen; consider an expiring cache if that is a concern.
    private final ConcurrentHashMap<String, IpInfo> ipMap = new ConcurrentHashMap<>();

    // Length of the rate-limit time window, in minutes.
    @Value("${lagou-service.gateway.ip-limit.duration}")
    private Integer ipLimitDuration;

    // Maximum number of requests a single IP may make inside the time window.
    @Value("${lagou-service.gateway.ip-limit.limit}")
    private Integer ipLimitCount;

    /**
     * Counts the request against its client IP and either forwards it down the chain or rejects
     * it with {@code 429 Too Many Requests} and a plain-text UTF-8 message.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        InetSocketAddress remoteAddress = request.getRemoteAddress();
        // The remote address may be null; fall back to a shared key instead of throwing an NPE.
        String clientIp = remoteAddress != null ? remoteAddress.getHostString() : UNKNOWN_CLIENT;

        if (checkIpAccessCount(clientIp)) {
            ServerHttpResponse response = exchange.getResponse();
            response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
            // The message is non-ASCII: encode it explicitly as UTF-8 and say so in the header.
            response.getHeaders().set("Content-Type", "text/plain;charset=UTF-8");
            String data = "您请求过于频繁，请稍后再试";
            DataBuffer wrap = response.bufferFactory().wrap(data.getBytes(StandardCharsets.UTF_8));
            return response.writeWith(Mono.just(wrap));
        }

        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        return 0;
    }

    /**
     * Records one access from {@code clientIp} and checks whether the per-IP limit is exceeded.
     *
     * <p>The window covers the minute buckets from {@code now - ipLimitDuration} up to the current
     * minute, both inclusive. The current request is counted before the comparison; rejected
     * requests are counted as well.
     *
     * @param clientIp client address used as the throttling key
     * @return {@code true} if the request must be rejected; {@code false} otherwise
     */
    private boolean checkIpAccessCount(String clientIp) {
        LocalDateTime now = LocalDateTime.now();
        String currentMinute = MINUTE_KEY_FORMATTER.format(now);
        String windowStartMinute = MINUTE_KEY_FORMATTER.format(now.minusMinutes(ipLimitDuration));

        int totalAccessCount;
        IpInfo ipInfo = ipMap.get(clientIp);
        if (ipInfo == null) {
            // The IpInfo constructor already records this first access.
            IpInfo created = new IpInfo(clientIp);
            IpInfo existing = ipMap.putIfAbsent(clientIp, created);
            if (existing == null) {
                totalAccessCount = created.countAccessesSince(windowStartMinute);
            } else {
                // Another request from the same IP registered first; count against its entry.
                totalAccessCount = existing.recordAccess(currentMinute, windowStartMinute);
            }
        } else {
            totalAccessCount = ipInfo.recordAccess(currentMinute, windowStartMinute);
        }

        // A limit of N allows N requests per window; only the (N+1)-th and later are rejected.
        return totalAccessCount > ipLimitCount;
    }

    /**
     * Access history of a single IP. All reads and writes of {@code accessCountPerMin} must hold
     * this object's monitor.
     */
    class IpInfo {
        private String ip;

        // Access count per minute; keys have the form yyyyMMddHHmm, e.g. 202005011210.
        private final TreeMap<String, AtomicInteger> accessCountPerMin = new TreeMap<>();

        /** Creates the history for {@code ip} and records one access in the current minute. */
        public IpInfo(String ip) {
            this.ip = ip;
            String currentTimeInMinute = MINUTE_KEY_FORMATTER.format(LocalDateTime.now());
            accessCountPerMin.put(currentTimeInMinute, new AtomicInteger(1));
        }

        /**
         * Adds one access to {@code currentMinute}, discards buckets older than
         * {@code windowStartMinute}, and returns the number of accesses still in the window.
         */
        synchronized int recordAccess(String currentMinute, String windowStartMinute) {
            accessCountPerMin.computeIfAbsent(currentMinute, key -> new AtomicInteger()).incrementAndGet();
            return countAccessesSince(windowStartMinute);
        }

        /**
         * Discards buckets strictly older than {@code windowStartMinute} and returns the sum of the
         * remaining ones. Pruning keeps this map bounded by the window length.
         */
        synchronized int countAccessesSince(String windowStartMinute) {
            // headMap is a live view, so clearing it removes the expired buckets from the map.
            accessCountPerMin.headMap(windowStartMinute).clear();
            int total = 0;
            for (AtomicInteger count : accessCountPerMin.values()) {
                total += count.get();
            }
            return total;
        }

        public String getIp() {
            return ip;
        }

        public void setIp(String ip) {
            this.ip = ip;
        }

        /**
         * Returns the live per-minute map. Callers must hold this object's monitor while reading or
         * modifying it.
         */
        public TreeMap<String, AtomicInteger> getAccessCountPerMin() {
            return accessCountPerMin;
        }

        /**
         * Replaces the per-minute counts with the entries of {@code subMap}.
         *
         * <p>The entries are copied before clearing, so {@code subMap} may safely be a view of this
         * object's own map (clearing first would otherwise empty the view as well).
         */
        public synchronized void setAccessCountPerMin(SortedMap<String, AtomicInteger> subMap) {
            TreeMap<String, AtomicInteger> snapshot = new TreeMap<>(subMap);
            accessCountPerMin.clear();
            for (Map.Entry<String, AtomicInteger> entry : snapshot.entrySet()) {
                accessCountPerMin.put(entry.getKey(), entry.getValue());
            }
        }
    }
}
